{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "dfffabb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "import numpy\n",
    "import pandas as pd\n",
    "import mindspore\n",
    "import mindspore.dataset as ds\n",
    "from d2l import mindspore as d2l\n",
    "\n",
    "#@save\n",
    "d2l.DATA_HUB['banana-detection'] = (\n",
    "    d2l.DATA_URL + 'banana-detection.zip',\n",
    "    '5de26c8fce5ccdea9f91267273464dc968d20d72')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "af4e71fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@save\n",
    "def read_data_bananas(is_train=True):\n",
    "    \"\"\"读取香蕉检测数据集中的图像和标签\"\"\"\n",
    "    data_dir = d2l.download_extract('banana-detection')\n",
    "    csv_fname = os.path.join(data_dir, 'bananas_train' if is_train\n",
    "                             else 'bananas_val', 'label.csv')\n",
    "    csv_data = pd.read_csv(csv_fname)\n",
    "    csv_data = csv_data.set_index('img_name')\n",
    "    images, targets = [], []\n",
    "    for img_name, target in csv_data.iterrows():\n",
    "        images.append(ds.vision.read_image(\n",
    "            os.path.join(data_dir, 'bananas_train' if is_train else\n",
    "                         'bananas_val', 'images', f'{img_name}')))\n",
    "        # 这里的target包含（类别，左上角x，左上角y，右下角x，右下角y），\n",
    "        # 其中所有图像都具有相同的香蕉类（索引为0）\n",
    "        targets.append(list(target))\n",
    "    return images, mindspore.Tensor(targets, dtype=mindspore.float32).unsqueeze(1) / 255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ad494dda",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@save\n",
    "class BananasDataset():\n",
    "    \"\"\"一个用于加载香蕉检测数据集的自定义数据集\"\"\"\n",
    "    def __init__(self, is_train):\n",
    "        self.parent = None\n",
    "        self.features, self.labels = read_data_bananas(is_train)\n",
    "        print('read ' + str(len(self.features)) + (f' training examples' if\n",
    "              is_train else f' validation examples'))\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return (numpy.array(self.features[int(idx)], dtype=float), self.labels[int(idx)])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ddf79d60",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@save\n",
    "def load_data_bananas(batch_size):\n",
    "    \"\"\"加载香蕉检测数据集\"\"\"\n",
    "    train_iter = ds.GeneratorDataset(source=BananasDataset(is_train=True), \n",
    "                                     column_names=['imgs', 'labels'], shuffle=True).batch(batch_size=batch_size)\n",
    "    val_iter = ds.GeneratorDataset(source=BananasDataset(is_train=False), \n",
    "                                     column_names=['imgs', 'labels'], shuffle=False).batch(batch_size=batch_size)\n",
    "    return train_iter, val_iter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "724ea48b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "read 1000 training examples\n"
     ]
    }
   ],
   "source": [
    "batch_size, edge_size = 32, 256\n",
    "train_iter, _ = load_data_bananas(batch_size)\n",
    "batch = next(train_iter.create_tuple_iterator())\n",
    "batch[0].shape, batch[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5da02d56",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bbox_to_rect(bbox, color):\n",
    "    \"\"\"Convert bounding box to matplotlib format.\n",
    "\n",
    "    Defined in :numref:`sec_bbox`\"\"\"\n",
    "    # Convert the bounding box (upper-left x, upper-left y, lower-right x,\n",
    "    # lower-right y) format to the matplotlib format: ((upper-left x,\n",
    "    # upper-left y), width, height)\n",
    "    return d2l.plt.Rectangle(\n",
    "        xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],\n",
    "        fill=False, edgecolor=color, linewidth=2)\n",
    "\n",
    "def show_bboxes(axes, bboxes, labels=None, colors=None):\n",
    "    \"\"\"Show bounding boxes.\n",
    "\n",
    "    Defined in :numref:`sec_anchor`\"\"\"\n",
    "\n",
    "    def make_list(obj, default_values=None):\n",
    "        if obj is None:\n",
    "            obj = default_values\n",
    "        elif not isinstance(obj, (list, tuple)):\n",
    "            obj = [obj]\n",
    "        return obj\n",
    "\n",
    "    labels = make_list(labels)\n",
    "    colors = make_list(colors, ['b', 'g', 'r', 'm', 'c'])\n",
    "    for i, bbox in enumerate(bboxes):\n",
    "        color = colors[i % len(colors)]\n",
    "        rect = bbox_to_rect(bbox.asnumpy(), color)\n",
    "        axes.add_patch(rect)\n",
    "        if labels and len(labels) > i:\n",
    "            text_color = 'k' if color == 'w' else 'w'\n",
    "            axes.text(rect.xy[0], rect.xy[1], labels[i],\n",
    "                      va='center', ha='center', fontsize=9, color=text_color,\n",
    "                      bbox=dict(facecolor=color, lw=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "902adf19",
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs = (batch[0][0:10]) / 255\n",
    "axes = d2l.show_images(imgs, 2, 5, scale=2)\n",
    "for ax, label in zip(axes, batch[1][0:10]):\n",
    "    show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a8f8d54",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96d591ab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
